import torch
from torch import nn

class weighted_mae_5d(nn.Module):
    """Weighted MAE for 5-D sequence tensors that carry a singleton channel axis.

    Selects channel 0 of both tensors and delegates to ``weighted_mae`` (the
    element-wise, threshold-weighted MAE defined below in this module).

    This class is deliberately named differently from ``weighted_mae``: a
    second module-level class with that same name would be rebound (shadowed)
    by the later definition and could never be reached.

    :param weights_prec: per-bin weights forwarded to ``weighted_mae``.
    :param thresholds_prec: bin edges forwarded to ``weighted_mae``.
    """

    def __init__(self, weights_prec, thresholds_prec):
        super().__init__()
        # ``weighted_mae`` is looked up at call time, i.e. the module-level
        # element-wise weighted MAE class.
        self.weighted_prec = weighted_mae(weights=weights_prec,
                                          thresholds=thresholds_prec)

    def forward(self, output, target):
        """Compute the weighted MAE on channel 0.

        :param output: prediction, nbatchs * nlengths * 1 * nheigths * nwidths
        :param target: ground truth, same layout as ``output``
        :return: the loss returned by the wrapped ``weighted_mae``
        """
        # Drop the channel axis (index 2) by taking channel 0 only.
        return self.weighted_prec(output[:, :, 0, :, :], target[:, :, 0, :, :])

class weighted_mae(nn.Module):
    """Mean absolute error with per-element weights chosen by target magnitude.

    Each target value is placed in one of ``len(thresholds) + 1`` bins and the
    absolute error at that element is scaled by the bin's weight before
    averaging over all elements:

    * ``target < thresholds[0]``                       -> ``weights[0]``
    * ``thresholds[i] <= target < thresholds[i + 1]``  -> ``weights[i + 1]``
    * ``target >= thresholds[-1]``                     -> ``weights[-1]``

    Bins are written in that order, so if ``thresholds`` is not ascending a
    later bin overwrites an earlier one (callers are presumably expected to
    pass ascending thresholds). With empty ``thresholds`` every element gets
    ``weights[0]``.

    :param weights: one weight per bin; needs ``len(thresholds) + 1`` entries.
    :param thresholds: bin edges, expected in ascending order.
    """

    def __init__(self, weights=(0.5, 1, 2.5, 5, 10, 20),
                 thresholds=(0.1, 0.2, 0.5, 1, 2),):
        super().__init__()
        assert len(thresholds) + 1 == len(weights), (
            "weights must have exactly one more entry than thresholds")
        self.weights = weights
        self.thresholds = thresholds

    # Backward-compatible alias for the previously misspelled attribute name.
    @property
    def threholds(self):
        return self.thresholds

    @threholds.setter
    def threholds(self, value):
        self.thresholds = value

    def forward(self, predict, target):
        """Return the mean of ``bin_weight(target) * |predict - target|``.

        :param predict: prediction, nbatchs * nlengths * nheigths * nwidths
            (any shape broadcast-compatible with ``target`` works)
        :param target: ground truth; its values select each element's weight
        :return: scalar tensor, the weighted mean absolute error
        """
        thresholds = self.thresholds
        weights = self.weights
        # An integer target would otherwise yield an integer weight tensor and
        # silently truncate fractional weights such as 0.5 to 0.
        weight_dtype = (target.dtype if target.is_floating_point()
                        else torch.get_default_dtype())
        balance_weights = torch.zeros_like(target, dtype=weight_dtype)
        if not thresholds:
            # Single bin: uniform weighting (previously an IndexError).
            balance_weights.fill_(weights[0])
        else:
            balance_weights[target < thresholds[0]] = weights[0]
            for i in range(len(thresholds) - 1):
                in_bin = (target >= thresholds[i]) & (target < thresholds[i + 1])
                balance_weights[in_bin] = weights[i + 1]
            balance_weights[target >= thresholds[-1]] = weights[-1]
        return torch.mean(balance_weights * torch.abs(predict - target))